Seeing is believing

Using FlashTorch 🔦 to shine a light on what neural nets "see"


by Misa Ogura

Hello, I'm Misa 👋


  • Originally from Tokyo, now based in London
  • Cancer Cell Biologist, turned Software Engineer
  • Currently at BBC R&D
  • Co-founder of Women Driven Development
  • Women in Data Science London Ambassador

Feature visualisation


Image convolution & CNN 101


Kernel & convolution


Kernel: a small matrix used for blurring, sharpening, embossing, edge detection, etc.

Convolution: an operation that computes the weighted sum of each pixel's local neighbours, as defined by the kernel
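The weighted-sum idea can be sketched in a few lines of NumPy (the `convolve2d` helper and the toy image below are illustrative, not part of the talk's code):

```python
import numpy as np

def convolve2d(image, kernel):
    """Naive 2D sliding-window filter ("valid" mode): each output pixel is
    the weighted sum of its local neighbours under the kernel.
    Strictly this is cross-correlation -- like cv2.filter2D, the kernel
    is not flipped."""
    kh, kw = kernel.shape
    h, w = image.shape
    out = np.zeros((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# A toy image with a vertical edge: a step across the columns
img = np.array([[0, 0, 1, 1],
                [0, 0, 1, 1],
                [0, 0, 1, 1]], dtype=float)

sobel_x = np.array([[-1, 0, 1],
                    [-2, 0, 2],
                    [-1, 0, 1]], dtype=float)

print(convolve2d(img, sobel_x))  # → [[4. 4.]] -- the edge responds strongly
```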

Example of convolution: detecting edges


Make Sobel kernels and detect edges with OpenCV


In [2]:
import cv2
import matplotlib.pyplot as plt
import numpy as np

# `image` is a grayscale image loaded in an earlier (hidden) cell

fig = plt.figure(figsize=(14, 3))
ax = fig.add_subplot(1, 3, 1, xticks=[], yticks=[])
ax.imshow(image, cmap='gray')
ax.set_title('Original image')

sobel_x = np.array([[ -1, 0, 1], 
                    [ -2, 0, 2], 
                    [ -1, 0, 1]])

sobel_y = np.array([[ -1, -2, -1], 
                    [ 0, 0, 0], 
                    [ 1, 2, 1]])

kernels = {'Sobel x': sobel_x, 'Sobel y': sobel_y}

for i, (title, kernel) in enumerate(kernels.items()):
    filtered_img = cv2.filter2D(image, -1, kernel)
    ax = fig.add_subplot(1, 3, i+2, xticks=[], yticks=[])
    ax.imshow(filtered_img, cmap='gray')
    ax.set_title(title)

Typical CNN architecture


Kernel weights are learnt during training to extract relevant features from input images.
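Unlike the hand-crafted Sobel kernels above, a CNN's kernels are just the learnable parameters of its convolutional layers. A minimal illustration in PyTorch (not part of the talk's code):

```python
import torch.nn as nn

# A single conv layer: 16 kernels of size 3x3, applied over 3 input channels
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)

# The kernels are ordinary parameters, updated by backprop like any weight
print(conv.weight.shape)          # torch.Size([16, 3, 3, 3])
print(conv.weight.requires_grad)  # True
```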

CNN Visualisation with FlashTorch - 1

Visualising saliency maps


Saliency maps


Install FlashTorch & load an image


First things first...

$ pip install flashtorch
In [4]:
from flashtorch.utils import load_image

image = load_image('../../examples/images/great_grey_owl_01.jpg')

plt.imshow(image)
plt.title('Original image')
plt.axis('off');

Convert to a torch tensor


In [5]:
from flashtorch.utils import apply_transforms

input_ = apply_transforms(image)

print(f'Before: {type(image)}')
print(f'After: {type(input_)}, {input_.shape}')
Before: <class 'PIL.Image.Image'>
After: <class 'torch.Tensor'>, torch.Size([1, 3, 224, 224])

Let's visualise the input


In [6]:
from flashtorch.utils import format_for_plotting

plt.imshow(format_for_plotting(input_))
plt.title('Input tensor')
plt.axis('off');
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Let's visualise the input - take two


In [7]:
from flashtorch.utils import denormalize

plt.imshow(format_for_plotting(denormalize(input_)))
plt.title('Input tensor')
plt.axis('off');

Load a pre-trained model & create a Backprop object


In [8]:
from torchvision import models

from flashtorch.saliency import Backprop

model = models.alexnet(pretrained=True)

backprop = Backprop(model)
Signature:

    backprop.calculate_gradients(input_, target_class=None, take_max=False)

Retrieve the class index


In [9]:
from flashtorch.utils import ImageNetIndex 

imagenet = ImageNetIndex()
target_class = imagenet['great grey owl']

print(target_class)
24

Finally, calculate the gradients w.r.t. the input


In [10]:
gradients = backprop.calculate_gradients(input_, target_class)

max_gradients = backprop.calculate_gradients(input_, target_class, take_max=True)

print(type(gradients), gradients.shape)
print(type(max_gradients), max_gradients.shape)
<class 'torch.Tensor'> torch.Size([3, 224, 224])
<class 'torch.Tensor'> torch.Size([1, 224, 224])

Let's inspect gradients


In [11]:
from flashtorch.utils import visualize

visualize(input_, gradients, max_gradients)

Pixels where the animal is present have the strongest positive effects.

But it's quite noisy...

Guided backprop to the rescue!


Guided backprop modifies the backward pass through ReLU layers:

  • A gradient is propagated through a ReLU only where both the forward activation and the incoming gradient are positive

  • Only pixels that positively contribute to the target class survive, which suppresses much of the noise seen in vanilla saliency maps

(Springenberg et al., 2015, "Striving for Simplicity: The All Convolutional Net")
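The ReLU-filtering step can be sketched with a backward hook on a toy model (illustrative only; FlashTorch's own implementation differs):

```python
import torch
import torch.nn as nn

def guided_relu_hook(module, grad_in, grad_out):
    # ReLU's own backward has already zeroed positions where the forward
    # input was negative; additionally zero out negative incoming gradients
    return (torch.clamp(grad_in[0], min=0.0),)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))

# Register the hook on every ReLU in the model
for module in model.modules():
    if isinstance(module, nn.ReLU):
        module.register_full_backward_hook(guided_relu_hook)

x = torch.randn(1, 4, requires_grad=True)
model(x).sum().backward()
print(x.grad)  # input gradients filtered through the "guided" ReLU
```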

Calculate the gradients with guided backprop


In [12]:
guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)

max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)

visualize(input_, guided_gradients, max_guided_gradients)

Now that's much less noisy!

Pixels around the head and eyes have the strongest positive effects.

What about a jay?


In [14]:
target_class = imagenet['jay']

guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)
max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)

visualize(input_, guided_gradients, max_guided_gradients)

Or an oystercatcher?


In [16]:
target_class = imagenet['oystercatcher']

guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)
max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)

visualize(input_, guided_gradients, max_guided_gradients)

CNN Visualisation with FlashTorch - 2

Gaining additional insights on transfer learning


Transfer Learning


  • A model developed for one task is reused as a starting point for another task

  • Pre-trained models are often used in computer vision & natural language processing tasks

  • Saves compute & time resources

Flower Classifier


From: Densenet model, pre-trained on ImageNet (1000 classes)

To: Flower classifier that recognises 102 species of flowers, using a dataset from the VGG group.

Load a target image


In [17]:
image = load_image('../../examples/images/foxglove.jpg')

plt.imshow(image)
plt.title('Foxglove')
plt.axis('off');

Pre-trained model (no additional training!)


In [20]:
guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)

guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
/Users/misao/Projects/personal/flashtorch/flashtorch/saliency/backprop.py:93: UserWarning: The predicted class does not equal the target class. Calculating the gradient with respect to the predicted class.

Trained model


In [21]:
trained_model = create_model('../../models/flower_classification_transfer_learning.pt')

backprop = Backprop(trained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)

guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)